import torch
import numpy as np
from torch.utils import data
from PIL import Image


class My_dataset(data.Dataset):
    """Map-style dataset pairing in-memory images with labels.

    Each sample is converted to a single-channel (mode ``'L'``) PIL image
    and passed through ``transform``; the label at the same index is
    returned unchanged.

    Args:
        all_gadf_images: Indexable sequence of images. Each item may be a
            PIL ``Image`` or an array-like that ``numpy.asarray`` turns into
            an array ``PIL.Image.fromarray`` supports (e.g. a 2-D ``uint8``
            ndarray).
        all_gadf_label: Indexable sequence of labels, one per image.
        transform: Callable applied to the grayscale PIL image (for example
            a torchvision transform pipeline). If ``None``, the PIL image is
            returned as-is.

    Raises:
        ValueError: If the number of images and labels differ.
    """

    def __init__(self, all_gadf_images, all_gadf_label, transform=None):
        # Fail fast: a mismatch would otherwise surface as an IndexError
        # mid-epoch (more images than labels) or silently ignore the extra
        # labels (more labels than images), since __len__ counts images.
        if len(all_gadf_images) != len(all_gadf_label):
            raise ValueError(
                f"got {len(all_gadf_images)} images but "
                f"{len(all_gadf_label)} labels"
            )
        self.imgs = all_gadf_images
        self.target_label = all_gadf_label
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``, with ``transform`` applied to the image."""
        sample = self.imgs[index]
        if isinstance(sample, Image.Image):
            img = sample
        else:
            # Convert to a PIL Image. np.asarray is a no-op for ndarrays and
            # also accepts objects exposing __array__ (e.g. CPU tensors),
            # which Image.fromarray cannot take directly.
            img = Image.fromarray(np.asarray(sample))
        # Convert to a single-channel grayscale image.
        # NOTE(review): float arrays become mode 'F', and converting 'F' to
        # 'L' clips values to [0, 255]; if the images hold values in a small
        # range such as [-1, 1] they should be rescaled beforehand — confirm
        # against the caller.
        if img.mode != 'L':
            img = img.convert('L')
        if self.transform is None:
            return img, self.target_label[index]
        return self.transform(img), self.target_label[index]

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.imgs)



